{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "from torchvision.datasets import MNIST\n",
    "from torchvision.transforms import Compose, ToTensor, Normalize\n",
    "import torch\n",
    "\n",
    "# `is_available` must be *called*: the bare function object is always truthy,\n",
    "# which would select 'cuda:0' even on machines without a GPU.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "BATCH_SIZE = 128\n",
    "TEST_BATCH_SIZE = 1000\n",
    "\n",
    "\n",
    "# Prepare the dataset\n",
    "def get_dataloader(train=True, batch_size=BATCH_SIZE):\n",
    "    \"\"\"Return a shuffled DataLoader over the MNIST training or test split.\n",
    "\n",
    "    Args:\n",
    "        train: True for the training split, False for the test split.\n",
    "        batch_size: number of samples per batch.\n",
    "    \"\"\"\n",
    "    transform_fn = Compose([\n",
    "        ToTensor(),\n",
    "        # mean and std have one entry per channel (MNIST images have a single channel)\n",
    "        Normalize(mean=(0.1307,), std=(0.3081,)),\n",
    "    ])\n",
    "\n",
    "    # download=True lets a fresh environment (no ./data yet) run the notebook top to bottom;\n",
    "    # torchvision skips the download when the files are already present.\n",
    "    dataset = MNIST(root='./data', train=train, transform=transform_fn, download=True)\n",
    "    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=2)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 激活函数的使用\n",
    "```python\n",
    "import torch.nn.functional as F\n",
    "F.relu(x)\n",
    "```\n",
    "`F.relu(x)` 对x逐元素地进行ReLU处理：负数变为0，其余值保持不变，x的形状不变。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0, 0, 0, 1, 2])"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# ReLU zeroes out negative entries and leaves the others unchanged\n",
    "sample = torch.tensor([-2, -2, 0, 1, 2])\n",
    "F.relu(sample)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型中数据的形状\n",
    "- 1.原始数据的形状[batch_size, 1, 28, 28]\n",
    "- 2.进行形状的修改：[batch_size, 1\\*28\\*28]  全连接层是进行矩阵的乘法操作\n",
    "- 3.第一个全连接层的输出形状：[batch_size, 28], 这里的28是个人设定的，你也可以设置为别的\n",
    "- 4.激活函数不会修改数据的形状\n",
    "- 5.第二个全连接层的输出形状：[batch_size, 10],因为手写数字识别有10个类别"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "# 2. Build the model\n",
    "class MnistNet(nn.Module):\n",
    "    \"\"\"Two-layer fully connected classifier for 1x28x28 MNIST images.\n",
    "\n",
    "    Args:\n",
    "        hidden_size: width of the hidden layer. 28 is an arbitrary choice;\n",
    "            any positive integer works (the default keeps the original network).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden_size=28):\n",
    "        super().__init__()\n",
    "        # 1*28*28 input features per sample; the batch dimension is not part of the layer shape.\n",
    "        self.fc1 = nn.Linear(1 * 28 * 28, hidden_size)\n",
    "        self.fc2 = nn.Linear(hidden_size, 10)  # one output per digit class (10 classes)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map images of shape [batch_size, 1, 28, 28] to log-probabilities of shape [batch_size, 10].\"\"\"\n",
    "        # Flatten to [batch_size, 1*28*28]; unlike view, reshape also accepts non-contiguous input.\n",
    "        x = x.reshape(-1, 1 * 28 * 28)\n",
    "        x = F.relu(self.fc1(x))  # [batch_size, hidden_size]; ReLU does not change the shape\n",
    "        out = self.fc2(x)  # [batch_size, 10]\n",
    "        # log_softmax rather than softmax: pair this output with F.nll_loss.\n",
    "        return F.log_softmax(out, dim=-1)"
   ]
  },
  {
   "attachments": {
    "image.png": {
     "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXEAAABMCAYAAACbFIdjAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAAJcEhZcwAADsMAAA7DAcdvqGQAACL7SURBVHhe7Z0HYFVF2vf/6T0QOkgxofem9CaCgl2wg/WzIrquu+vu932i6/qKFRF72XUVK6iAqDRRUKkBaaEk9ECoCem93ff5P/ec681NAgSSwIX5weS0OXPmTHnmmWfmzPVxCDAYDAaDV+JrbQ0Gg8HghRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGw3mKo9ShzuDdGCFuMJxnFBcWo7SkFPn5+fDx9bHOGrwVI8QNhvMEW+s+fOQIHn3kUcyZMwd79yTqOYP3YoS4wXCeQK3bUQrk5ubi6aefRsOGDcU10Gvbtm5T7dzgfRghXptUl/nRmDFPDpNOSmZGFvJz83XfR2p8aGgoduzcgY4dO6K0tBSZ6ZkoKirmVfVj8C6MEK8tqkOgGKF08rjSytrhpiJ3jrPst2XYuGGDCu1flv6KhPgExMdvw549e7B61Wrk5+WjpLQEc+d+C18/S4ifB+lyLmGEeG0idaSosAgFeQW6LRbtp6S4BIUFRXpMxy5tQX6h+qENs6S4VP2ZilVFKI9UJh0n4c5xxXP1qljMnj0brdu0RkREJA4dOoSAgAD069cPIy4dgWGXDENYeBgSExPRqVMn6y6Dt+HjEKx9Qw1TKgJ58vOTsX79egQFBqHVhRfCRwSJuy0yMysTaWlpyMnJBbPGz9dXNCQ/5Ofn4auvvpZKF2r5NJwcTFvRVTxL+TkswGn35uu9//77WPzTYvTs0RMXRl+IwYMHo3mL5k5PFnk5eXht2mu47bbbpDy2ss4avAkjxIn7eA5Lfw1W8HW/r8MTTzwhj/DBs//zLC6+uI8IazsCPvAXgU2ZQw09Oztbur17sWzZb1i9ejWGDh2Ke++99/jxY26ewwLqpLBLtKaDQ/7JjpxzJYsP7b/shFod0aqk2fH8niVpX1JUIlFx4NVXX5XydTEuGX6Jxi0rMwt+Ur5CqQhYcZ33/TwEBgZixGUjnDefTVQ1X2wquofX3cM7Xri2X+IFtgq/fwrW/rmLlSHaXvE/P3KQfc/2i4K1XOYeL7NPgSaNm6BJkyb4/ofvsXvXbnTp2hWNGjUSjdxHhTnj5Osn2rdo4MGBwWjSrAn69OmDHqJN7du3Dy1btpLzQX/Ele/D1xCnh7ItlQaAA1b2+xYVFaEgv0D3fX39VPuvEnIfew9Nmza1TsgpSUOaeSSm+pziYtn3kbDl2o4dO/RaeHi40/MZQKKsCVIqf+yOjk6J1ncvFCeNpfPAiduu82bB/RzheXF8V+aX9qDk2L0c8fwZw4qGr4+v2rkPHjiIzMxMNG7UGNu2bUPc5s1o2KihM18YTfEfHR2NmNYxzvnivL+Wos9eKdNR00/KEss705P7mrZSfnmtlNeYphXEi6ZGXzve7lTgl1mkp+1rsmUcyuSXvcs42YHyPtsPT7l5d1HZ+Vri/NDE7fyQV42L24zpH3+sic7CfPvttyMmOsaVCSocbbhbQ5kz7bVp+OSTT3DjjTfiscceU/MKW31HiRRie4DJDRbu/fv3o06dOoiMiNR4/fzzz1j842KtDK1atcLERyaipKQEX3zxBWJjY1Xo3vt/7sWChQvwyy+/4PE/P442bdsgMCjQCvUkkLTbuGEj9iftx1VXX2WdB
N5+623s3r1b48045eTkqCB48MEHNV3/+9+PcP999yEoRN6rtrEqFTcp6dnYtnU7Bg7oBT8tByXIzzqCLVt3oX3HngivU76h2bRxkwqRHj17WGfckDCo4aYkpyA8IhwB/gHIy8+TxrUlxlw/BpF1JW/OJHxHcdykpaWqIC/Vnp6P5FMkLmh2AfwD/emzPLyphsq7O0n7D+Czzz7F0SNHERUVhRukDsTERCP1WCpeePFFBEn5vOWWW9CzZ88/4iTu448+xp133alhnDRMC4o4+a/C2AqLLF+2XHsqgYEBvFj2/SXJWAaKS4q13rVqdaGEU4oA+iVWuN9//z369x+ABg3qq+LC4+vGXOf0U0ucX+YUvqm4rOws/OlPj+GBB+5H3359NcOcl525aOeje4ZXNxnpGXhM4rBi5Qp88P4H6NW7F0JDw5yaRRWeOeubWZg5cybeeecd1K1TVwpdCb755hsRpGG48sqrtGxSGxs+/FKsWL4c/gH+qumfLNlZ2fj3v/+tNtNGjRtZZ51d9j8//md0aN8BEyZO0HR96623dMbDO+++g507d2LVqlV44MEHrDtqCas0F0sFTNy3Hxs3b8HmzfG4YpSkRXEOApGPtKN7sS5uFzr3HIhBQwYiJKRso7Zw/iLkF+Th2uuutc6UZcP6DfjPf/6DcePGoV//fsjMyMSrU14Vjba1KAXjz7qvIO2B8UqF9xmAg/njpEwNv3S4KhpMM5a1xYsXo0/fPmjZqiX27N6DlJQU1dJpBnrppZfwZ1FEbCKkEaUClnQgCS+++BLSUlOtK38QHBKMW2+5FSNHjsTrb7yONWvWWFeAw4cOo25UHQQHB4ryQ43fDw9NmCANXTMcPXpUzUxZ2dmYPXsWBg4chJYtWsDP3x+NGjZEU/FTVFSIyZMn4z5RVtLTMrTocRyCx4EBgdLDboioevVEaarZdPcCi081IqnMNosFgwOFbdu0tS78Qbnqx5yx3cnC57BhqOg+6xy11YcnPoxOnTrj+Reex4GkA39M8aoC3bt31y4nTRhFopFvjouT7nMjpwCX3KVmvpnd6AYNECiF1RbgRVKJOHfY8wMPml1yc3KtIxFYGzZo4xAeHmGdcZKRmaHmoPbt22u3lOF07dJV4lCE7OwcdJH97777Djmyr9jpUFGaVDPUoPzlPcNDQ1E3MlIFQIlkSKFUOlFbRHuW7rpUUJ/IKPiI1pedk4VN62Kxds0qrF27BkdEQ0yUnseadWuwZd3vWBe7Bjk0UVnhBwcH65hFg/oNRECWaPgRkRE4dOigS2Cy51QpdhrYrrrxCJMNNwW4lsna4jjvxrISIHG6+pqrER+fgMysLOTm5mH+ggXo0rWLCnDe6+frhwAR4D6lxdLbLIFDCrSPD00gko0icP39QrRM0xz03vvvYubXM53uK3Gi2HA7ffp0jL5ytL7/4395HG++OQ1/+9sT+Mc/nsSYsTdoL/gvf31Mtn/Dp599iXbtYvDbil+RU5CPdCm70S2a4Il7b0R0q2hENY6GT2EhFkk8cwpLcfhwqpT/LggMDUdeQaGUCYfUnTx5P6c66HD41LgAJ+eHTdyCApwtPs0NFHyjrxitFdAubNTBc3JzdE2JgsICFBQ4Xb4ItgKek33aqY+rKdth0Y/tz9O/+KFwbdKkqWreP/74I+ITErRrFxERIfdW8gCG7X5JjutJS7906VLRsldgwMABavvs178/QkIlnlalnTFjhmgSA9Glcxd9f2qS9Jedk61bCiNOPeMUtCVLlsj2MLbEbZYuYgNs3bJVG5wOHdqX0eB3bN+hFeUvf/0rQsNC9d6pr03FsKHD9Fl8hbVr12r8aGpwxdtKH6WS1zwtJExbE2ZvhKamrVvj0bpdO6lkeQjwK0FRXioOFgAdhoxE3VB/HEjcjYU/zBHtvUQapyzs2pmIjLxsBIUHo+BQsvRsZqHj4P6IDAlTS/qmTZtU+7vs8svgL5pZYuI+LPn5Z4waNUoEUCt9PoU8n
80GrSBPyo6WIadj+WKZYtnz82eI1Uwl6VpZsaoJtK6ptLVOuMH04bUSEebz58+TBr+LluH27duhe7fuznvopDFOSxGN2M+BUv9gLF+5EgP695YLLIf+aNokRh5UgkU/fi3lfyU2boxT09+WLdtkfxM2bdwox9ITi9uJ5s0bIUiqxKpVyzHllWnSY9wgQvgQVqz4BcuX/4xfl27A1VeORZ0oB6Lq10fjC1phk4QT6ZOL1gHHsDQuGRdfOgAx0lg3kHob3qQR/vPudInPQGQU5CA3Oxf16taTXl0I6kt9SktLR0xMjFOIe9bbaua8EuK2oeS1qa+hZ88e2m1T3ATL3LlzsVIKCyuq7eLi6DaLi5NuVFNE1rHsntZ9LLDa+kqhIwXSKlP7yc6SCiyVNSg4SCs1u4WKRINaBqcPtm3XFtsTtms3j40EBbnL7nYirMIe4B+ID//7IUKCQ9RuHVUvynWdWs+XIsRvvvkm1KtfD+t+X485c77FyBEjECbC+Zelv4gw8dfpZdRKrrjiClwo++y6DhkyBHsT9yJCtPB2IgSt5FNostm1a5dqSAsXLtCBz/Hjx+sMmiDRcpkmNKmkpqahd6/eqh2rcLXT2op7TZMqXewt8dsQ07atCvFAEeLFuek4mFGI6I69US/EDzkZqSooLht9hQiRi5AqQrdxowYYO3oU2nboLA1aHC7s1w8NJH0pcr/++mttxI4cOaI20AMHDuL6MdejV89eTqVA3otCKnFvIuaJkNqwcQPiNse5HHtGm2VLPw0aNJSGoKwgpzbPLyg5QM1yczKO91R0vjYdGy3Go1C0VZZlKgZ2g1oRHE9ZI70cKg7jbhuH7j26a0+OtcRH6sbe3buRWyTpERCBIN8SrP5tqSgoQ+UqC5E8UxShOnXCxUWgadML0KJFKxHWLbBzx060EMWhQ4cOaC1aevPmzbTs+/r4IynpECIj62D06MvR5+J+6H1RT1E6euLo0TRcNnKkjhcVS49t+bJf0ffiHijMykD64WTszQ9GtPQSwqUM1G1YD4m792H6ZzNxyfARSD2chDnS0FM5OnT4sNTltcgUZYCmNpcmXoNl/fyxictbMtNZ8VhgXpnyilY6qifu6esalfbAmUpO7cJ1Bzfiprw8RQdAGD6nBVLTYmHmGhWXj7ocN9xwQ/nCzPDkFE0RaelpuH387ar9T5o0ST/CqBTeJ041Hcv8QlPM6FGj8dFHH+lgnMbQukaTx9SpU/HCCy+ojW/atGlo1qwZbht3m35Q9M9n/qkNBzVyCrznnnsOCdIrmPzcZHzwwQe6SBIbrpGXjXTFmVx5xZU6KHvXXXdphWVDpaYFq6Fi/GbMnIG9e/fqlEpbuLlsEty3wqpJ+KXiV7PnYMSoq0QDz0C4fzHyUvZhTWImho25G9H1g3BgVzwO7d2B3v0HocQvGEtXrUVR2jHcePkIoLAEU996GyMmPowukZEa5fHjbteBt0svvRSBIqj8ApxC2DZN2T0WCjbXe3siackyUpGdmmlbVcqZb+xn2nnmcbkmoQAePvwSTJw4UctFZRSKsvPyyy/r7Kx77rnnj14J30XqmY7hBIfBNzAE3838RIT6Tjzy2BMSPs13JSKMmyCmdUtRekqwI2GP9PzW4dZbb8bc776T3l9zKYPFmD17Lv71zLNaP6lYLfvtVxHkBzFgwGDJHz95VC6CQwrw9lsf4//933/qIGdAoC+SU47gs8+/xE1XXI7sfduxNsWBoVdfh6aSXSziPyxaBN+gCERF1MGxA4kSdgAuGzUC+QUlWv/ZC+0rSmJkZIgz/SsqA9XEeSHEOYBBYdKrVy/Vplm4OIASGhLqkbiiPSUmSibkWcduSCqxcnJ6IO3DTmF+krBO25XJhvXdPi+8/vrrWqkffOhBNU9Uivh3zzI2DgvnL8Tf//F3rF+3vtzMFl5bvmK5asns5r304osYO3YsevTogaPJRzFlyhTceeedeOONNzBk8BDcefed+OC9D9RcM
GnSU/jii8/RuHETXHb5SBTmF6oNMml/ks5XZ+PQsVNH60kSL1uQyIYzIj7+eDpycrLxyKOPOM/XFPI8/RJWGkHOGOEMnxLR/A8c3C9CfCuWrYrFqKvGIjTAF36ijecfO4CFa7ehQ99L0K9rW+RnHsOnn3yEtNxiUQ/DsP/IUYQV5KFL44YIlV5OsuTL+KeeQrOQYBRKN/maa65RW2vTps1ESHnM9LHzVaAQZyPKsQpPONOBMzMowChcTht5LlO/hMLK3+rCC3x2gDQUnI7Ha7UJ6wh7lVybhfniOaDOevn888/j2muu1QFOLdsacR+USBlno5iRfBDffvUZSsMa45oxY7F4/jdo26YTVq1cj2uuvhnRbaJw7NhevP3ml3jkkb9J3fQTwT0TMW0uQPfunaSnGYvgwEYYOLgb9iftUNPlnNk/Sq+1gQhjhwj3HEn/QuRkF+Kpp55Fg0ZNkZ6dgeycXISHNIVvfhq2LZ+B0nrNENO5DwJLAuErDcsFraPx649z0FLKQGpyrpqF6tSN1OwvLCrS2Sp3332XnuNrc3zq5AVG1Ti3zSnO8qBT8ZZLq86ZFJxpQQ2Knx7zmlOvtv8xsX3UrhUWFlbGhYeFq/mBdlbOw6Vn5x0nge3NzTu1dppXOIeb0/9oV73//vs1fA2Xfq2K6H6fDSsIr2dkZOiskPbybheJRh0aKi2/G7///rv6pblm2LChqnHzPbh63ZKfl6gZZdiwYWoaYeMQVTcK0z+ZjptuugktW7QU//GqnbRt21bjwWlgixYuQsL2BB3x50p4NpqGbNwkDSnAZs36RhtOTmvU0l3Be1QLEu7ixT/hu7lztVehzxeBlZAQjyNHD8qxP9IyspCWkiruGA4nJSE4Igq50iiFi9bVtk1r9OnbTzSpqzBU0igsJAwdolvh3gcfRJ+hgzFoyBApE8HIOJqMmV98icPSZabZqW7dujqmUQ77lLwzNfEwaZSZr2UdzQB1tHdUHenCvGP6c7oje4C5eblavrJzcnT2DMdaKCA1f2rJsZ5wxgmnZLKhatWyletdOXhO88Oy5cvQt09fHTeRu/Sa8018sHnLZvy8ZAlWrl6DCRMfQUR4KPaLkhUbu0Z6K1ehXfvmyMsrxvwF32p9jm7dXLVk9iTr12uIZhfEICQoHLFr1qFR4yipY3txLCUN3br1Utt7l67tRZnphk6duqNVqxhkZWWhYaP6aFi/PiKkvgf5h2CnhBUbuwo+wcGo16gxWl7YRfOtToQ/4ndsRT0pA6WZ+aKB56Jth3Y6g4uNM01k3bt3c5pN/e03E6ohrz05t4W4lWBMdA6icPCtbbt2uP766/XLSLekdcGZB9SEudKbuwsR4ciPbLTLJ8KThbTKuN1i28c5V3XVylUYM2aMTm3ScG1/FOJOb2VQ04z+91GNmQKBNn7aIGn7c6dOZF0VJBRuHKhkZcnOzlIth+tm9O/HQdAQnalDTZoDu1s2b8G48eM0HQ4fPCQNRabO4+WcaC4JwKUBunXrpmnFeceu+HJr7XMK4ocffqg9C/qr6D2qkzwRXFzUiZ+X142qq9oo7aKsXNu2JWCUCOhBA/qLIIlGsXTjt+3YheEjRqBrp/bwl7wICAgSmeujNu8E8e9bXISObVszozQb/ES7T5KGlsKIM4I4YFtfKrvr3d2xzjEvmbaeZcl2DEuF74nK0nEacxv9vkDCidsUp9MfmZ916tbB559/ruMR0aI5sszVttMiLKooe77s6bHc+YhwT0lJRrr0alq3bqONIXu4Wq7lBs7SKijIx9EjR9ChS3ekZBXgxx++haOoAN26X4SOHTtImUzF++99LMpLG+kRNUbrNq1EYJZIGH6iuKwXId4MTZs0lgY0HIFBAaoZt23XET17XYQLLmguvYCDuPb6q0QOSJlOzcMVo0fjot490aBhfdHAg7BE4vvTj4t12mjfAUPRuEUz5OQX4Muv52PooP4ID/XF73EbERoUim/+Ox0bt8Zhe+Jub
Ivfhm1bt6oZL13qSt9+fSQdJBGYdyfI5lNGCtH5QanlTgcrjNLi0w1IkP7ijoQdjucnP+/ISMtwnnMPVq6rk3O5OXl6SpFzouXqbnFRibhi57Gb/3J4nuOxfU62SfuSHKKVy3NyHb8u/dUx5ZUpjrTUNL2cn5vvmPrqVId0fR2lEr6+u32/HUYF7Evc73hq0lNOP87o1jhvvvGmo7CgUPdzsnLl+U87Fi6Yr8cFEodCuVRSJAcSJ77GG++975jz/Vy9XiwneYlRXfrrCsdSSQ9HcbGcoJOzvGC/s+1OFwlDhJp14IY8i/nKd5HG0HmO3rhrHXrCcHjPnNlzHJOenKR5+dXMrx0J8QmWjzMDy+iunbsc7779rr4Dj51xdZbdPxyPnY7vXKqu2JGbneeY+908qSPpGl6JZBzLuzrJMB6XMp+EhQvmOa4cPdrx8ISHpfxmSFiSp4XiX/zl5xU4Uo4mOyY/N9lx1513ObbEbdF7yA/fzXfcMPYmx7SprztiV65yPP7YY3I9zpnmHiQdSHZ88smXjnSpH7Fr1jp2bN8hBcvNo+w664lzW1l+VSdm7ZSqYKeUqhfO3VMlcU+izvAYNHiQzpLhT2ZRY3G3G3Jerzbe1imuSlevXpR+PcZ5tqRcb+IUWvwlPy/FjBlf4qGHJmDlihUaJ85ft+et09bOwRpq7coJ3l0aANXC77jjDv3MuzbgrJv09DSXbZVp4CihbV4OREOlNSfA1ogseM5+Fd5iZ6umu/grLSiAr/RunMgV2zPDoGfiFl6VsO8nHmFIxdfBbo7jULvmLA61FVv3sMpyoI6zQNjL4fx8joVwJhQHCvlh165dO3Xco2u3rs6bqsLpvpuNhMNewKxZs3DdddL7DfDTgfwy6VcJZR5taeh2D7RCmJm85nld7mN+OrSf5YPUtFTUF62c/vndQEBAoITr9Lo/MUk0+CjtgarpSVXosoiQ1ng4exQSIgsK4YZxqIjywVQrRohXBfeUqqwwnSSvvvKqrokyaMggLRgaILvWLKjEKhTOwuSDpSJoY9fE4v777ldzgVJRzvG+U4ibaOP6lRorPQejOFDIqYcU5Iwfp4K5pmQeL3yJU0LCdrX1d+na2TopWK9YU9D2ShOF3QhSgCvyTK1nbs9mpeZxEdfOEO/8MIinrLkRTtwrpJ3OdmWswfcgnLURvy0ej/7pT7joot76+TkFtPuccs4I4iA850NzGYl6DeppGjz5/5/EhAkTdFZR+w7tce2111Y9vvb7Vsd7SjpykNuVL0zXKoocNsTMQ3cFxxOa7+xZQq6yJlvno5zPU4FrvZNo/05zllX3WA7YePJZHOegWNSPxjhI7J4OjD+jYT+DOIN34r5P6Mf9/hrACPFaJi83T0fkudYJlwblglQsNLQTSgdMBzz5yS4130OHD2Hvnr1Yt36dzrmmLf++++/TiqCCqaIC4749DTjvXcs8GxV5Dgv08SqRC/HL+cLlCn9tI/FwpZMdbbf04q4oU5C+KPxYaeXYJSJ5QMf7WGltGFYtvBO1PH5r8Omnn+oKlq+9Ng11IiMlSvYLsLxAp7MuXLBQG93rrrtOv9qdPWu2zvfnoOCCBQvw8MMTEckPyDxmLVWK9e7c2BqqLbiOpRzTOd0crGZ50AFrC1sonhRMd2v3RGis5Y8uhHWcd6AAdr2jFd8yeSfYA7suuEs/3Nqn3SJW0TO1TDFsTSA9VeaeCikbRLVjhHhtIinN+doceOJ64Zy1wM+NaRrRGS8Cp4RxJgS14NzcHKdQl+50auox7EvcV3aBJY9CWqYwnins0nSm42HjXtmIW/y4ax8SW2YpduV2p7bfSeJw27hxOjD9yfTpEleJrRVh27TAX+b57bffRFvvhVenTNHFml56+SVd8XLSk5Mw/vbxuOPOO8qnw/EQv2xIfP19tQdAwcePYJKPJKti8dNPP+HRRx/VngC/2uVMJk4a4MAyn3dCPMvtibDjfbz4V/R+FT2nTCYLVnqWu
/dE6VWV9KxhjBCvTSrIeC4EVFxcZB1JuWMXTgQ4F+5x90tbIiuue5dQw3PPPZ73CN9QOZ4Fv0zSVVQraitt7XIiQujAgQO4+Zabcc/d9+Cuu+9S8wCXZrWFOKEgD+Y35SKgOLWTZYS9Jppl2CPivkuDtLGfUQnJR5PVJs8ZT7TJd+zcUXuRXN6h90W9tWE5dPCQTlnlB2b79ydh0OCBWk4p/I9LRWnrSW2l9TnACVLbUNNwDrZOabQcP4oI5ronHoVY65xUznJtLv3ZzlAl3JPurEw+qZ38uvaZZ57BokWLELs6Fn4VCEg2+LaphXZz9upo66WgtU1gbPsVq/hoMbKdBxT+H3zwbyQlJekUPS4kxTESLtGQnJyMFStW6DxvDr6mHEvBgvkLVIAzrBMKcOKZ8BU5w0ljhHhtchqFUyuv3K8amI1d4N2doSzuwsrdnQqnel9VcctH2nkvHX6pfoQ0e85snfmjuL+HbNmDcw7SCVatphB3YYcpW5pKSktLdOsS5m4sW7YcP/20WG3uO3Zs1y9UGT5XC+RgaUxMa/3mgD1GLl6Vl5eHPbv3mvJ3hjBC/ExjCn7NwvStyJ2IivyczH01gdRSrmfNj2X4NSAbcs7+cSHxoga+atVqPPXkU7rMMD9uqyy+/AiHXyfu278PGRnp5fzFxq7GiBEjcOMNN+rXzT179YS/fwDatGmj+zHR0cjJzdX1Qbh0BReuotCnWcdQ+xghbjBUBoWbu6st3LRjasu0Pa9ZE6tz4HnMuAQEeax0Kf65lELr1q31p9k406kMHto2V2Gc98M8HQj1hMtTcIYUodnkm6+/UTOerelz/Rx+JfzAQw+gc5fOzt9+ve9ep13eUOsYIW4wnIVwVhKneXJgc8aXM3S+d3BwiGrRhFPqbLbEbdEtBTI/ZNm6dav+bB/Xudkev93pErYjIT4B+/clqWmkW/duaNK0idOc4sGIkSP0F2n483vx8fFo16692tTdp5jaU+94znXeSJMzgkl2g+FsQ+QjF07KzMjQxdFuve1WREZyaqklcbkRqUpBv+y3ZbpGOe+hxr569WrVkvmTfEGiGXN9HNtx9U3+fqU2BFZQrgFPN8JCw3RlSy5BTPNJp44d1VyjM6MMZx1GiBsMZwMUquJosuASDFyJkIsp8YMwCnCaM3xsm474S09Px7vvvoenn34aA/oP0J/BSxehz9kk/G3TsTeM1VX7Vq+OdblYEfDr128QTVwCcJfHlkC34WAqtXX+uAhnvri+FDUy/KzEzBM3GM4mKMjF8Zd/uA4KP6X3VJd5yJknXF62cePG+Ne/nhUt/KB+Tck52zNnzETnLl3015zcTSBEP+KRc4cPHsa8efP012/4Y+Hun/QbvAsjxA2Gs5SM9EwdoPzDjOHccpkGmkR4nss18PdU+VOAx46l6Prwa2PXIjomBvUblF2W2Nb0ObuFPx3Hed5cO19t3qZP7rUYIW4wnG2wRlJec4IJt7YMr4T8vAIEh3j8DBrv9RTMlhBneM4Px2RrhLfXY4S4weDllPukntgNgSfu5yvzY/AqTDtsMHg5VdKmjdA+5zBC3GA4FzkZYW0E+jmBEeIGg8HgxRghbjAYDF6MEeIGg8HgxRghbjAYDF6MEeIGg8HgxXivED/d2e28vyph2P6r+tyq+q9uXPGWP659cQp3PJYsrWnKxeFEVOCxSvcbDOc23ivET2d6lAoAWxJQiJXKnvOfYl9yd6dDuXCKxTmf+4erjgd5UCY4OSgXPM/xpOvNa5dKH1oizhkvZ9pYuwaDoRxeIcTL1F+rbp9yndYb7butwFyOWELV/ZQNG45TaTzKPILh8xdQ3AWV62INwbA9I243HG7UdDRcuD2kwucViOMv17gJ8IriVuG9BsP5hW+FlcOmqudJVSrWSfql+GF1pthT5D4fq367nzsuvE7nEsQMRJz+AopTC3fuSbvGT+DYvNHficKtCvpc/gmQcPVAsCPmfHp1PpCPKOZjXM91b
ool8dj3OKnnnUp0jnuPXJR0dzAOzEQ7I3mPnQTginqBss+tU8/Q5LJewR16r/BxFZ40GM41gP8FY3T+rannvvkAAAAASUVORK5CYII="
    }
   },
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 损失函数\n",
    "![image.png](attachment:image.png)\n",
    "在PyTorch中有两种方法实现交叉熵损失函数：\n",
    "- 1、直接使用 `nn.CrossEntropyLoss`，它的输入是模型未经过softmax的原始输出（logits），target是类别标签：\n",
    "```python\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "loss = criterion(input, target)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
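    "- 2、先对输出取 `log_softmax`（先softmax再取对数），再使用 `F.nll_loss`（负对数似然损失）：\n",
    "```python\n",
    "output = F.log_softmax(x, dim=-1)\n",
    "loss = F.nll_loss(output, target)\n",
    "```\n",
    "\n",
    "两种方法的结果相同：`nn.CrossEntropyLoss` 等价于 `log_softmax` 与 `nll_loss` 的组合。"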
   ]
  },
  {
   "attachments": {
    "image.png": {
     "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZcAAAAdCAYAAABmOiEHAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAAAJcEhZcwAADsMAAA7DAcdvqGQAAC/uSURBVHhe7XwHYJblufaVnZBNEjYhJICA7D1kCTJUFNTjtnW1VlutWmddp9XWarWec1pH7e8EFBlVZMtesgl7hA1ZQAYhe//XdX/fG77EhK2e9uSCj+/93vE893M/936eF69KAt8HLrZVL/d3PepRj3rU418O3u7vc4PjMM7kOHTN874z3Xsm1PXcxbR5NtRs+/vqpx7fD2qTjTPNoXN/bc+dCedz79lQW9/nS8+lxsX0/69E98WM8WKfPdvzF9OH4NnPxbRzETizc/EkzIPAsyY7znPOpy5caHbyPWU1VeM6F9r/HfBvNr6Kigr3UQ3UNU7n/I8413Xq0o9Ej+FMfZ+Nrh+S7lr6Ou9CzAWOtVo/59Gl89zZ6DzvcdQGNcGPteXZ3CVo+lxwZudCI15WWlZFjDNgrzqse3lZufvIjUoqu6PwekSH+tbHGaDz2/nUgtKSUvcRmyivw4DURC0MzMvNI0nOYNwfD1Rdq4OOaqil/XNBeWkNHv1Y8KBf81Y19hqo4vcFjvcHh+fckWbJrMkwUVxYbN8GXqugfMohVQ1NB2cYZ2W5+6K7D7XrzGd+XoF914k62jVdOhd5u0CUFJ/Wne/AGa8HbdlZ2cYXB7XKRR1jMTm6FEbxLCgqLKJcVtZKW03bdMbxC2fifW3X2KXGaH07bDqP+XP4U8HnNYZawdMaR106aTJXhxmsRpcb5dThavMievWzZvPOuZrnLxDfXXNxflUjgIwg0V7e3na6oqIM3j6+1ZgqRfP14znP1kzYyuHl6+/6LVemgeu5Mh741eHb1IZH29+B2rBH3Td63s/j8vJy+Pj6uE/UDjmaBg0awJtjsmdNYHigv26y8vPyERwSbMcnjp9ATKMYO74YSKAKCwrg5+8Hf3/yxaH7R4IU9bNJn+HY8eMICPBHSUlJNUH08vKCL+c651QOxo4diz59+vzoNNeFMipdeXmZ8dXL200khyIH6e1bi6y55cYxihqnjc1zfA4rPM5VUL7EIpMx9/XSkmL4BQSAqkGZ0hnnQTfIRz3DL1dbumzflSim7vj7+hqv7SZd4F/RVV7hoq3amM4D5dSz5OSjCA+PQERkuJ0rKylDTk4OAgID4Ofnx3kPcJ1nf7t27kRsq1YICwujPrj5pgBRg6rSu1rgJluBoPQvMCjQnHFGZsYl0RtPyLj6+NWh3w7bHVY5v4ni4hL4cs5M5wkn09U8VtkvN0xmfLzN7nlrXjxZzzZLOE4fXq/kfb7kYRnlzpdzaKhtmtz8cWiv5Lx4uWVSDkFz677FBZJWTF0sLS2xOTLZICQTQcFBdlz1gPOt4ejbaUTniRKO22yNplDBIuXJ25FdjzZ0KP3x8fFxyZpz7QLh85+E+/g7yMrMxPFjxxHcIBh7dichP5+T41OJzZvWo0VsnHWs8ZRyYr6eMQMJrROwk8K5N2kvjh4+itTUZKSnHcWJ9Aw0adzMbt6zKwkVFO7g0FBkpqSZMB8+kowj/Bw9ctTuCTclqKASlCAl7RiFPxCZmVlYuWINWreOdylopTIqOi5yQDzQSZ3WZGzcuAHhoSFUnkBd4X284mZSASNMOZaw0DBs2pTIdjMRHRVJQQL27zsMf79APudnTuCzSZOQnZ2NmJgYvPLKK2gd1xpR0VGuhmrDOUzGrp27MGHiBAQGBqFFy+bus2fBRU7ymSCF2r1rN5YsXYpbbrkFw4cPR4/uPdC9Rw/06NEdl112mQn23n17cezYMXTs0BFBDdzC7QmbBOJ86bwUY2MbMhQHDx5Ebm4uGkY1dF8g2PaC+d8gqkkTBNJ5VoHPyAjKqKxevQZBQUE2LsfwOFC7mzZtYkBQiIYNG1IugMTETewjEtt27kNMVDS86
VF8GCgdp4Peum03Dh0+gEOH9vNziDJ9GOkpyWw3AGERIUbPyeyT2LljB5o2a0aZS8KWXbuQfPgIknnvEX6ysrIQHhFh/Sdu3owJEyaiWdOm1WTPIlT+rS3zMWPFP/JTonX37t3o0qWzja0gv4DjXY1Zs2dTnwvY3xHsYv/79u5D8+bN4E8ezZk9B82atUBwcLC7aTqjzFPYunUX72mK5KNH+dxR0wvH6Dk0SLdms+1A6qwMoQKSn/3sZ66LDkiXwf2MEiWvGueqQdc8zs+eNRst4lvjRGo6dm7eisOHyTPah0MHyW/SJkOqvjdt3GRzkJKSYuOb8fUMCxIlx3pGYxcUZGrcnpgzZw75HQNfOmAfGlpPshRczJ47Bx06djAjXcBgcQ7H3KF9h9N08tsZkmyaAgTxypGv/FOnsHljIpq3aIn1a9cihM5cfVWBD/vSCSXSRm2gPWtC+dVYpk6ZSr3sQZp82SYdlJyS2uafExknbK6DyPvExI28n2M8nI4Mnj906ABy8/PRuHEjZB3PwG7aof37D9g8HqHsHT56hPccREpyCtLS0xAdTbkmrRcS0Dj4rnPxKkLawd2YOHk6Vq5ajvVrliJl7zbMnT8PizfvRtLuHVgw+2ukcWKPpBxHs9i2dBZF+OMff4cR147CouXLcDwzGwE+FK7SIgrwCUyZOgPDh42Cj783lixazMEHoHFUOBLJ3GUrtyGSSlNMp7B3z346msPo3r0rBa6Ux0ewYsVaTtplOJZ+DF9MnolRo4YZmZWVxSgtK7aJ/WzyZE7AJqzjZ+3qb7FqyWI6uCQycAcFbCPbWEFmNTKjs3/ffmzbvgNxdFLH049j0Yw56NG1I/wa+OOl53+PQYOHokFwgDE1ISHBBDCUjnAtBWDwoMEICw8zx2NMp9C4MjqPCTjLXJSWlmLhwoVmwJs0beI+ew6oklQPXPi8V0HRpca5ew+NIoV36NCh5tyVscm4RNDItW3XFv379TclbdqsqZ2rFR70VDN+DnTIc06pTd8ygN+JDM8IMaLGzfz53nvvYvHiRdi4YSOdzAGsX78BGzZswGb+njhhAk7m5zJQSUPeyVOY9fVMOoyN2Jy4GRt434IFC0ypNtOQb9u2DcuXLydP2mD9unWYPWc21qxZy2uJNF6HKdPgs2vQrHkTLF+9CbmnTqJ1q+aMZPOxY+cefDN/NQOpGEayhcwEylBaXITdlLcDxzIoZ52N3JSjyVhDOe3Rvi2mz1iI1NxAGu8T2LNjEyIbNsCqb79FixaxaCRDkJGJjaR14ICBJnsOqkp09uXmspstXgym5BSLigoxceIkjB8/ngaUjoJy6ufrxyz0FP45fTrPj0N8fDyiGkbh008/ZVOV6Ny5C7PYE0hOTUE8eVBZXkQDWoJlS1Yh5dgpBnexmDLlC2Qx6IqMjOT4c82hK8uXQ/F2p/3KjGS0xc/x14+3cwbSqDk3gyi9MfpdY6mmR57QZV6qoMPwoui89vrrGDZ6JGZO/xLFufkoonOXsW/cuDHnaYvJp8Z0mEazpJg2gk5UmbcCVDkeBRLK0KWL0m0FBdOmTcO2rduwhc+vIf8XzJ2FA7RF23bvwY6NK/Ht2vXwjb4M0ZGB2Mn75n7zDXIYJGxetwGbtmymbZqMsoIiBt6JWMpALZq0rFi5Eu0ZnAll7Gv9uvXIy8tDdKNoLFu6DAcYDMXHxeMvf3kDbS9rhwLGy7mn8pB9IsMcfxAD0CZNGiPnZI5lhCpXKhju3r27OU9ff1emJD2qYMZxmHZ7xepVaMW2Jn4xiXYsBD7lAZTFLGTnpOKLqQsweOBwHNqfhOSUVKRRH8SXDRvW0zFlWNZyMuek2VTNXVhY+FkrQGdCLZkLG/YuQ/PYdpwIXwpMBcYMGUhi0hER3xX9enZBFrORO++8B42btsI3i1dg+qRPqXhJSDp6AI2atkDH9l3oTAYgvm0bRIT7YOmSNbj2muuo5J/gq69mcAITcfLYEUaCjRESEYerx
w5EQpvWCAmJQcaJ4+jcpSMFrRJ//etfMWfWAnrvtVhMh3Fg33Ea+Q2YM/dLzJs3g8wuQ+MmTTFj1iyMufpadLy8E9qzzyED+lFJOqNVbCtc1qEDvfkeZifRjMaaobCwiP1vRXy79mjVMhY+hUwZfUv5KcMHn07HzbffxezFi8IyBf/4x/sWDUtYtm/bzshtKw3YYnzNCGjRwkUUgHxc3qmjm29uhXEiuVogIZAyyrmMGzeOk9/AfeUsMIVUadJVO9XHTteljOcBMzhMu5syMp5OgxNOgWrVqpVLqNS8uwuVTlq1ijOHo/KZxuL0X1xUbOPWOScy02/VevVdjSc8tHN8VvcqClT54dygcbstTQ1EMasoLCyg0W+GYcOGoU2bBNeHQcSu7TsxbMxoylUX7KZhkKL37tPbykCDhwzBVVeNsHkdP2485/NyZr6b0LJlS8TFxdE4tEf//v2tHNiWMtOoSTNsWrcSfXv3RJuOvRAZHoqwQG94+1XiaPIxTJo4lwYgjZkeM/i9e3FAQQ4NbGjLePTq3QNJG7bg7Xfetuw6ee9OVAbEoOuVNyG2UTj8kIeRowYzk9jD/lvTeEexzaM4zAh01CgGZ25FV3BTWqayjE8VL53SigMvHy8bk5xJ/wH9TXZkkIQ9e/ZY0DT22rEWcPn4+pqD1XGv3r1QRJ04xOvt27dnRM8+S45j8bL16N53OJ1lLttdYFmfoD6lB0VFRZRL4LHHHrXoWY56/rz52Ld/n+nNvHnz7LOWjjqO8qXMz4ILBSFqR39qyEk18LZkBgCvvvIKdpOnKZkZKKTRvXbkKLThnG1jJnj//fcyAEqlDIchoW0CA8oYBpbLrc/9B/aTji10JMdw8MBBpKelo1+/fujYsSN28FkZ1GbNmpuN6NmtCwb17ISs3AL0HHQlLk9oSueyEQEx7dGicST+68+v4cGHf8WANx2zv/oaBxiU7dq9C6fobBSADaX8xSXEY+GiRZzDhpbxFtJZTP/ndOOvKkFvvfUW7rvvft6zBE2iIy0wCgiNBM0eCvNy6Tj8rGqioGkdg5xv6fCWL1+BrVu22rHaFj93MHCJjY3FO2//D1YumYfEbVuRySw7LqEVenfvTdntySChJeepEIuX7sLVlKPi/GzEtopHXHwCEuLj6HhOokXLFhg4cCB1pg2a0qaKHyEhdE7e7jLid1XurDi95mL/Kj8t4XEFdn+7AV708KU+vijLL0R0RGMqZSBKyk8iIMgHJ06UICwmCs3bNENxeg7uuv12fL10Do4cTceWjdsxbuQIIBj46p8foKggCLfecbt18vnEz9GZTqBTry5Y8OUU/OXdqQgj88spZKVFpRgyrC9+/dh9VL6VTBs34+f3/xrFxdnYsmUbli1OwpO//blLKOkEqVuMBIpwOPU4U75ieBdX0GAUwiegEsHeAfT2ZajgTUH04Ko5N2/BCLO4BB988BGuvO5GRgXRCEcx04kcTKODPBEUj5/ecQOqTD55IuHXn9//7vd44sknXKUCKpTS7MmfT8bzLz5vSvvf//0/Vm6YOnXKGSdiyeIljG6/wQsvvkjn4lE7FWpRKNWJKypogDnBmmgHoknRl+q8VjO/gMk3uPsWT2fRST/zzDOYNXsWBTLedcGhTXD3oTKGDPQxKqoM2J49SXjl5Zetvq4Sm4zxm395E0OHDEXf/n1dD/EZ0SxHpGe2bd9G/lOIaZjKaKgefuRXrvvOBZ40CW66JjD6njz5cyvXOPCr8GI0uRmfzv4K7eJaw0drfTRiUp6PPv7YHM6AAQPo7K/H+++/b+UHqYQMuQKK5557zqJ+zUF65kl06jsQaeuXwpvZeopvDHp164hf/mQc2ndpjUWLNyJpB/CTu8fQuOexHwYtlI1t65k9p5zEz39yC1SoTd5/CIk7NmPsiAF476PZiOl9M0Ly9+HEoY248aZRmDBxKoYMGWWZw9SpU5FfkG+lpcLCQnMKy5Yts6jziisGYfiIK40fKinPn
z/f1vJkbDp36mxzcNttt6FL1y4uZvA+je2zzz6z8tFDDz5kgUUig72PyYunnnoabWiUMzNzGEDNwtXXXI3GjUKwZcV0/O3jb/D0f76LN199Gi889xRmzpxFvvW3zMeBrU9qbtzzoRLwhx9+iNeZaTh49dVXMWLECHQh37dv326GXuXIAP8AK9H27z+AxzXWlzxlm+3fd8+9eOvjD/H5O+9heN/+CGwag//529/w7G+etMBNwVGffn1wMjsH77zzDn75y4dcgQwdbGBgoGUBqmYUMau8/vrrzanJWUqnxZtHH7gXYfnH8I/ZqzDkp48hPooBWGkFfJiVJR9JwV333IV/zvkaq5ctp53xxeU9eyA4NISmswxLli5hgN0UPehEFtJhpKWm4QZmjp98/Al52Qg3/cdNnL/leP211/DWf/03lq1YhbuuG4Pf0b6MvfdB8r8VQhv42+aT4MAGnB/qPMes8u0OVWI2bWJgf6fNs7Kc+fPnWSB9WfvLkMgAfCkz4l//9hls2LgWJ9NPYuSV1zAyLMczz/4aw0f/BCMG90FB1mFcM/5+NGkRB2/a0azsTNq1EEQy48ujPvqzbS0DxLaKdTP9wnB6BUvQ2kQlTzE9eu/dv+PRp56kwJBpAUFIPpyMkNBmdASBHGsJ01N/zPlyJn716N34f59/iZxiX0z8dBL69uuNzVvWYdz4ESgt0EJ4PgZdcYVLQJhil5QUWERbmXcS4dGN8dsXnkYglbe0uBAh/pXIp9HKyCnCytXrcdWwK5F38iQj2yJ6c6ZrednIzaax9cqjkrh2/vj6BiGMxjXl4CGERDZFDNPRzz95D917D8blnbsycvOhgT69S0i1VdWfZ8yai5///C7SlY/CnAIk7jqKX7z4kJGpYMpkmx9FgAf3H+RzARahydmEhoXasTcdl8oTMvr33nsPbr75P6orQg3IKK/+djUj4HY0WHQsjiKe4Rk/Kty8eQtxknyQcjjQscalhddrrrmGcxPivqJ+6HppPKVE6sKByk8y8Eq3A4JO13fVllLu0aNHc+424+Xfv0zn94LLwdRCmyofUv5ddCQ9e/aks/6QilqMRYymZAAHDRlEZx7OQOMI+vbta2WdHj162niVBcrYtWGG+e677+L6667Hu4zOHmYkqDKFasCtW1PozyWbEW0aID9yWkrjR44caRGpA/9Kb/yVjr+EbSvW9qGRyTh+3CI/RfSpqak0zEfMSUdGaO3N2+bJPjTEw4YOw6233mpz/g/Kd8f+V2BvqDdCvSqRGZmAXz5wO/zL80lLqSnounWLkJG9hyRlGZ8CSWNeVjaadx9ipMrJKlJMP3IYuzcAmVlpWPHlNDTGCQRXZFHvjtJpZ2PwYFf/csJDBg+xscoQi2fPPPsMmjRugkGDB9GRfGJy8NWXXzFyv9/Wa977+3to0aKFZRAKhhwogNC9Bw4csExD1QCVrxRtP/DAA4hv3dro8/fzt/q9ZKiixB8H9u+1cpJKJ0MYMDSKaWRZ0+cMroLd2becxYgRV/FZv6rF9pUrV1mAIUfmGCoFRKJJjkVt3HjTjRzrYDz26GNYvWY1unXrZovu2XTkgYEBVrYyiHnkQVlegenCrqTdljlMmjARfgwSJTtvv/02UjifDz74C9u0UFJSTPtzHG/8+Q3bGGElbGuq0sYtvXHk7FjqMWzfsR23MUj+6JOJuGfsVdQJLzTw9UJJQSkCg/wAZjIL586kM6o0e2A8XLgE31K+tReInTC7SmZGcp/R27ljJ+xkgKLy6t69SbQ3P0N6OvvhnEZFR6OY2WdvympgiLLfIEyYNAkRDZktVBShfZv2uONWBeQE21YJTGtjKumdyjmFqCitv1WavomnKtd+OHkhsk6m4Zt5c9ChQ3ssmL0MI0dfg4JTZeRnMfr07U0SCxEQWI74Nq3xi1/+hk7shGV0MZzTVpwjtadMU/0I4qNTfjtfVH9KaamXTvniljvvoeKlMSs4BN/CfKzfvAelVKYu7
WIYfRWjY4drMG70WBxhRoGASMS0vZypWAscPrITjVuEY+aynWjbKo7jD7EUXzsVSgpzyNyj2FjijaiYJvjr+5MweMQwFOflcBJL0Si0CNsPMeILbY4efYYxKynF/j37KQB5SFy/FsfTcijou1BSmcfoxo8GowJdOvVBy+aBeOeNVzH6/kfRmmldq9AwNIlNQAxTPrdoVsPAKwZg9Y7DWL1yNcYMbo6Fy9cgttNgNA9vYMLnmZ1LSFRG6NG9O7O1Exbx3H333TbB4eFhrgiLfyMiI+xzJpRRmBR5PPvbZ+0Zg/yFc1yjX0HlgtGjRls/NcsGllXxU3MnlGr9imyzmaZr0c+BU0ZRyScm6HR0r3PaKVXuVYHHH3scjz3+mBmmquxFcNPj0ChD1a1bVyxZsgRdaViCGgSaY+jNiE2ZXJcuWjdzlfG0nqE6sfrp1r2bZQdaD9GioX7LKChyVHR5MjsbFRTyc3IuboPj4qEXDWcRMjMyKWPpumoIgA+NgCu40G3l5I0UU9nVIRr4KwZeYf2Kz85uIdGs+1TDV5lg7rx5aNG8uZUO5Jg1j8Gc67RSZvlEBZ34hxP+wajPB3GtW6BZMyqmtz8DGy2IlyOTvEnZtReffTgd8d3ikLYrCQUZWcjLDUHjJpHoN/SnGEBWB9l4yqwMFhEZZY5A45EjVuSqTOQKBmrNmjazflV/V5l1xcoV6Hh5R0THRGPfvn1W4lRJ2OGpA41LRlVra5IpOQQZb5Vt7D72z1vsWBm+5lgip3JRu3bBxrfu48Yxq5OTKMY999xt63HWtgw3ZVQ79tSO1hb27duLm+g8pkyZgl49e6Fn757UoQxz5JJBBUba/BPfOh6DBg2iXg5EZMNI21GVySxYaydy6gbSoU1CWWnpJkenTuVY4HvbLbejUYe2rj07lDsZb41JpT6NQ2sHDzFzUQlUm33EA+nBqlWrXOtDbhlStit+qGz+m2efx45N2+AdEIymDSoQ4KUxUSaK82mge2LWovnmANWOSnDaaOEbFIhCzlcOs8xSBudqWuslHdt3wMcffYRHHvm1ZYkFvH7ttWNtDbcFeRAYxqCQ2XQZZfTeRx9Gu4RYhHiXo7xEzrDcgldHLpOSkqx0O3XaVFw39jrL+LSZKoRBzZIV6+Ab0xkd2ran7h+lc+mAUNrYtUl5yE0/gD69hyK4AQdKPnh7l1C28pCRcYw2OdvmSsFDSoqv8ciCcfGFcH9dEE47F/Zb1VJJJvp3b4qP/7kIwZcPQv+IAuQxismM6oaRQ9ri8JblWJK4Ar946E7MnrYeN187FJu+XYgOXfvA268UfQeXY+I7HyFtSxAio5tQCFw7MeZ/s8Ayhw6d2lOIQhhdNcStt43F7C9no0fnDmjbvAGOT1nJdLAZ2sW1RHlBIaLoQbOyDtGbJuHW2x+wRc7oJm1tolwSxVksZLZDJezUhkJG5lUEBDLN/wTekxvAv6wAURRYlQAiGp42/r+6Yxwef/xRVOZejWUrt+BXDz1W6y5LGWktEt98y81Wg9fi36xZs+nZAxHKiMP4do5QW1nZWWjevIUt0smIqJSkxUcJkGrcUmi/AI7NPR+KyCZMmGCOzXn/QMrlxXHLeGuRV6UPzx1S4o2UV9GtW3fsH3uOf2xbogd0XvChk1JUJ2Mz6IpBdq4u+Af6IyYwxhYm7/8ZI+bMLFMcGTZ1umvXTtxx+x2WtaiuLUWXgVRfylwTEzcbfcpwDjLrlBHQDh4ZMNsSfB6Q4di+bRsj2hzb9upZqpFzUbnmMA2qL2lI3rMPS+kQta2+lA5k8uTJFlmrbPr888+7skPSf9VVV6Ero+hMZgJad1Pp0cvHD775BQiNbojiAF8ElNOgFFTAl4b46jE3oLCIDq4okFnhKZzMSWNUPR+PPMxs2IvRZUUoP16IiApESAk7yD2JXiOGYMMnc0hlBYKqnKkvWiUkIOtEpmV+chAysnIKchyPP/64zdHWrduM13J4iuS1g0gylJKaY
utmmp+goAZmAKvAaU6lDMrAdmVgoOBBgYBTblW79s3rkhGVXrz8gtGnX1+s370QRdRHL4aWkhYvasvBA4dsLvVkA/al9S71J5q1o6tr165Gu7I/lavkDNLZv9aBtFtSTmbmzJnmoLQ5QNmvHEtefp6147lxpIIGWL+bRkXbWkA3Bi/71jBj4DWXfLuCAwUux4/JwdBoczhaoP7z63+2rcLK5HWPjGdubh7uUKmeD8/8eibmzp1rMurLrE1zkJySgZCwBkh67lmOjbKSn49xV12Nrr16mQwrU5acdmYW3qUPMwIfVQWAFcuWuWwTUcb+ZOSlw+0ua2d8NadPp6FytzcdcoCpdCWKGdhtYjZWUVqAhhWlCI+KQXSjRua0hQ/+8YE5kU6dO5nOvPqnV/Gb3/yGfKLccAyb163CL8bfiEXfzsPoG8fTNkVg/PXDsODLd7F3Xxruf+Ax44evbwiKGL+fOJFK3k9HgJ+X7fBzXo1QUBGibNGYWvV1QfiOFpeXl8AnkKkuDXYZmR0R341OgKkUmRbZKBrN41uh4Ggoio4VoIDjGnHtSISX+6G0IIMKGGRCU1ZKr5i6n2lrNt795DPLWlSTGDt2HKZO+YopZgCCglWW0VbickbYWRSGcBRRgPz9G6C0sBhRIfz2DwLnDJ9N/hJPPf0io9kg20547/0/VSBBxivKksKkIzX9OJbRed0w/lrkUWFuvPEG9GF06sfxBFDRakJp7uhhQ/HCS2/gnb++g7jYRq4LNDicAdcxod0jMtQSDkFGV0r4xRdfmGILtsBKjy/HYOmkYydqQBF1yxYtbWugDGI0I56nnn4Kb731X7Zop0he0XIVOHZN+H3332cKq4zETrudhMROUbQMczXwORlDlTec344D+Q48TisKVk38tddeq3beoM50zvl2Q0ZPhkHbO48dO26Gw+VAm9sip0oHH0z/kMrwRxQz2n3ppZdsy/OypUtxPaPgnTt2YicdkRYTte1SO+haxrY8vb7EOdZzp9eVSIDGIjrcUIQnx3sFI18di08yXBal8vsxGuR8n0pGaUW48sor7eNHmZAhk+N45NePIIwK9fLLL1vwo7U6P39frFy+0rJTbQPWQm9WZja++nwyRo3pi6OMDv0qfZC4YRO8GYwNZpupqScYeMzEzx64iY6oCMuXr0PGiWxG+Rmo8A1jxvoCmkU1RXlOYyTRUZXkllK+fSmjuXQ8dAjuMaWmJOPd997DyFGjkMDMSXLTuFFjcxaSA/H8/ff/jj/84Q8WpWv+9b5SQX4hPvnkEzz55FPWjsoc2snlQOWTfBruxnTqkjM5Fmd3l3il7pWIy/iqXZf8MGL3ViZXjvAGoS4RMDq9OS+l7FdZoU64ggbxPflosjm5m268yeZfTuzulnebDGotK5TR/htv/BkD+g/AmrVraOTvsDKcSpR61UDZsDJuBW9NApq4jDGNuXTBtgazDV92KYoZ2xtNzlZm6cPvfv87PPmESvrBxrenn3naAhs5GAVlYeGh+HbVt7ZgLqg8pi3TjlwfOpyK8f9xJ2679Sbcd9+dtFX+CKSh8WY2WnzqFPmoefNCCXV+0mefIYhzrgJ0ORlzksFI3FOuEqB4IeOv8rmcr+kg/6o8p3K3SohawBfxGkcBx6gsu4Q2S+/CaW4q6HS0fXrixIlYsmyJtRsbF4s//uGP5jg1b5qrxx/9JY7s2IPK0lOIDFOgSf2pKMbi2ZNx810Po2mzxvBV32WVCAxuyOD7I2bNsWazFixcYK9ZqHSp9cXc3FPmyAxunlwIqu8WY0PeUupKet78TKzexSgovhfahgHbN29CZUg4QityaATWIxMN0adHZ0R45clv4MvZszDuuqtxgKnwe2+/gzaXdcItd/wE8+Z9Q0Pha4ZUfNyyZTsNSHM05O8SRbJkwoqly23Hhg+ZHRoRSWFsaczfvnU3Fi1cQsPeHh07dWHaH2OGe/OWRJzKybVdDbm5+fjqyy/x9LPPYfmKlbYWs
nnrFgpwBEoo2FkS2pRUqx8rus9kRLhtyxasXLYCu/Yk4Q8v/yc2rFtL752P41TiRkxlJQy6/5tvvrHasGrZilJkQJP2JlkGoh0bSuWVEktwp02fZgJw1cirrDxQG2SEM05kWDqrqE77y7WFU2tSyoi04Od6x6cGbF4YVZEGRV5SLuejiKMuATAn5P7YPZ6fGhAdKvmNGT2mennPZTcsgpIhVlt6Qc5Br169LbOyKK5zZ1t0TExMpHO/0TK01au/5T09bbeeFExlJb0Fft1119k7TuLlDTfcYFGp1gLkdC+jI3d4qJLkiy+8iLHXUfmr4DEAHooPEXQCWm9IT2d2yG8t0sshaUsyiaax90E8FSiYhlgJr3YzacF+U+ImPEFD1JfR+d/f+7tliBpjI86FjJ2M0/oNG9CubVt06tTJoup+A7paMJG4cT/8SGeTxpE0LEV4992/sX2tx53i3KYi75SPrWX0698L/foOoGbT8ZWUY1/SXhrmMnTq0pc07EbOiWSUF5UiPSUNaexzO3WhDaN5rWdpx5kcQUJ8gq1PTJs6DdtI99ix19n6hGQptmUsZX855TPdZOqWW2+xklMIjbPKL7179rb3dCZ9NsmlZ5QllUIS2iRUZYlaWxQz9a/eaVMW2JXz6c8MrbQ8jUHWEbRv0xcNY7QwDsuclC0pMlcGGsX5ValIrwAoC+jTt4/xS7zcu3ef0aZtydqQoKwwgtF39slsDGVWs49jVMSs8p8MpWhfv369zYle9NTrA2+++SZ6Uo52U2Y2Meu9jvKwn+3qPTyJqHRcTk1rHirpKjPTOpSCjjff/IstfEt+lUG9/qfXjQ5lFMlHkpl9dDbnrEBS7+lNnzYVH33wdwYfRZg142t7J8XP1992fskxfDVjBsaMudocVMeOHTDuhvEmbwP5UZCitc2Y6Bg0CA62bEwyrHUlyamIlc5qZ6Y2E9h6hncFlsybh6FXj0fPLu0tMItoGGn37ty1i3L5Hv702p/M2cq+JVF+Uum8pS96Z6pHz+60WV7IPJWNIylH0I18X7NqJSZPmow773kQOfmldPbJ8OYYyulwtGP3g//3CaZMmYYZHIvWQZcvW26BuzYk6ZzsTIvmLU2+LhS1/6/IStWKMrF+VzKdS1e0i6jAtEkfojQ4Bk0jg3D0wD70GTGe0X5TBIDRUF4RnnnmOdxzzz1Ys2YN+lKwtIArUU2iAd+ydSuFJoTGeCA+/PAj9O3Tz4y4oiopt7KBaBr+r7/+ioP3ofIOYXq53JxHh/btER/fzKIlWw4ibdreuJkGTMIhBSopLsLo0aM4UYG2cLZmzWoqRRD7DGPEVWrRrMotSif3UPBUNujaqTNaJsQx1KmwnRla1Ny8dTP7r7AsRQqRmpZqxkECrtqn9qVrC6lYplKW3gmRcCiD0jsFckhnfMmSULQsyIH9hQqjKFoTuX37Dos0BwykEfoh4J51jUWKpYVjbZG8/PLLZWOs9qzoUZmIFsNTGE2r1HHrbbfauy4OFPloHgTxSbxUScNZBHzphZdsXUW8Uj1dfNSiYQNmplosFB+8aaDl3CTUiga1G0Y8EfRCn3YSDR8x3OhyEW4HIp6HXnRWWUzxvzYjrNKiMk21K5QVl5qhXLR0iZVTtGVZaw6+HJdeKJOBVl1fkbmgFyoP7N9Ph9DPshW9CyN5u+WWW2k05yMyoiEGD+lj2dS0aXNpTIJp+JvbltJ+ffui4+UdGI2vYhCRxUwQdJp6N0VrM4xUOY6EhNa2BhXfOg5taUwPHDyAdXRewXS8Bg6NQ7J1Qa2hiL/O+pNlGgz+9D6ComKxQltsZbRHXjXSFui1YeS555+zLCeP+vVHZqKPPKzIlXMmtrnnXbAFbvc5ZQFyBJr3uXPnke5IXMnMvhKMpIvTMHXKUurpOLRqHUEdgb3M14fjjWMU7cRS2gY8cdIEjBg+nHPe3aJ20bxo8SLLYJQZD2Wb2hCgMak/GVxlPxbFu9cO55PPe5L22Ps5L
RigyEbIycoYa3uzXo4ec/Vocybr1qw1XlS4ZUHvtWg7ueRAxl/PyvnJ2bTv0N5eol66bJnZBMm9NkXIGankqAxQu+z0vAU3bFLls/UMPmULunXtZmtwTzzxBJ586mmsZDCmMchZSJ70jGyanIs2lbRu3drGL0ceGUVn4YbOPfLwI+b0AhswI6fB//PvX8GIOx9Eh4Rm0JYpjSXlaApmMisaPnwE2rJfBdYKjtdvWG+2RpmcSs7aAMPbGbTsofyussBcfcoBNo+Nl5K65PrgIdrBhgwulOFJb8uNfs2NMnwdO1Ab9l6fye+Foe7/cl9n2ZcKMTSfHFQa/CiwwSER5nt8Kkp4llQyK1FtMY0KEx4RbhOtlMr13yHovy2AMUKRudI8RQ9afFSEoJqxIlaVg9RfUVEJJ/kAGjaM4TOFVL5Ye77KlrBjZfEqDUgo5bjC6EBUI1REqvtMAU1w9eEzHJ4mXTVblWskfM0YGVAaSbse4LdCMbZXUVbGlPiwtVVUUuIyUvbfX/A23iOB1PjUv6MYVXC46Ka1VtS4R28QaxFRKb/eZ1C9+GybAi4ZSIs5FgqWAgK9t6PIUDQo4j9FQ5BPARU/9FKiFhOlMI8++tjpLdSE+F3FB2d8gnuM2o4q51C184z3aG5VPhQ/nWdVWtSaixxTLPnulBb37d1v8hFB2XK16dkJwTHIWSmrjKMyax3Ky0pqGqAuq2wiOXX9Fz6CFE+K79BYDXxGY1d5xXMdS+ePpZ9Ao8YxJoOnISNVwms55F0jBARJY3SD3j1heKW6rQc4ZDoI6hX9sfO/Ihk8h+WmSy//WSnI7WgNus+D7l07d+Odd96293CU3SorUNZl97EfOUcZ/TFjRhsvapNZMwH8Kz1RlK1dXnKmIaEuh1dRXsyM7iT1OhLBof7UEyAjQ6XsUAZxdAqkRzw5laNF4uOIbxNvma7aqyYfnnCPw/O6dFT6lpqcity8XNM/GXwt0usB56VjZx3CHADh6YBVRZCO5jFY0bqfjrt07nL6v4vRUPm8tiHLqbi2nsvYltoWbqOLTTl9qMQq55F3Ko82oNx2uklOlfU5YxC0nmNVBAfqhx9tDLEAlPLtyg7VNizA0YK+a6txCTJS0+AV2QoRZLkPbad2qqmyI3tVZeDZnvWjNR3XcDl4ftzHxYV5lJk8Zu6Zth1bY2enrhuoE5InZbBRMR4BsMcYqh070Dg0l27azwe1OBdRy4aoGKLJcS46r9jOxsGfFSRaZ7UE7BpbHZ3rpho9XBDU/Lm04ybDGOLhiavgnPKkSyyQoOq7tmccaKCeNHje6rTnYkZ1ONd0vz7O8Y8J0qCpl+PX+lFWlqv+fCYMGTIEvfv0cv9ywfjsVvIzouaYL5gHetCNmvPF3xpTlUGvupX3XFBfNaEG6RUqnbUsZWzKINyN25e0xK3xlW6DVlvfnuP3GJLhUtDqGB1+q7ynrK3WKFR986PykQz03j17bc1L24CdcRl5xms2xjHpUPBkfZXDdZ+r0r/axuJ+/pzGqXtr3uc8L9TVhsdznobdzutTm55eCGqjzxMOrbrH85hwW1rXTzK1gvzS2pGXvE+NoMSg5/VxLlVrT605N7gv2JxJBp1e7MYfDLU4F4cwEsK/UhWXG5Fz0QoJj/m33IjWODlxdlQH4XrU3eT/OlTRpQPRz+/vw7nUhNPd/waQlguKTOriQ124ZGP26NiUx6NRu+RxrupW/r4kfQsysO5JNufCY8eJCHoJ2TpzR4we5NYK3VrznoulVe3p4ymL+l1Xu07/vG7/yWEAnad4W9O5uBvVv9Wa4o+qaXC+v3OTB6xBoq7rnqitHed54VzakN11eKFn9fHkzQ8Fj3E7ZAhVpDgn7IA31TVu53y13/rhfDRg96Hk0K7rn5oNfr/4LoslUPyUMUqR6tguDBMyRjZGHB9hmKIHTz9cB9E2uH8FuOk/k2NxoFucz8XgfwtvOI4LSXnPmw8X0MXZUaPRmn3Yb/5zKfo2P
dDHU2V0TC3xKjp93XJ9Rzv0cqXrXZg68X3JQa28qAO65r5ujsXgPkED4KXx2U+Oid+mJu5v++iSJ842pu888D2jZn8/dP91QGToY+yq4pkH886Xj27bfVr+9HF6+eFR95oLIf930aSdjUH/SjgTI5xx/jjzWI8fAjVl2SIvt5ZUZTNO8ViC4D42hT8PXCoZ8qRXbXpG8GdDtbFyHJ7ZmVCTxktF8/8ReLJXx2KflcMM9usCeOq0WnPiz7uhS4IzOpdLgu+39R8WP84c1eNfHeerA5dazpz+z6ddPePc79g8TzjX6nXionF6eshoC0T4uWC+1iZsP84k1TuX80G9ItXjUqIu3bjUcnbael0Yfig66/FvhXrncj6oV6Z6XErUG+16/BvjXCuw9ahHPepRj3qcM+qdSz3qUY961OOSo9651KMePxbqy1/1+DfG97/mUo961KMe9fg/h/rMpR71qEc96nHJUe9c6lGPetSjHpcYwP8HVOzlUOoiacQAAAAASUVORK5CYII="
    }
   },
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image.png](attachment:image.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型的训练\n",
    "训练流程：<br/>\n",
    "- 1、实例化模型，设置模型为训练模式\n",
    "- 2、实例化优化器，实例化损失函数\n",
    "- 3、获取、遍历dataloader\n",
    "- 4、梯度置为0\n",
    "- 5、进行前向传播\n",
    "- 6、计算损失\n",
    "- 7、进行反向传播\n",
    "- 8、更新参数"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里在GPU上训练MNIST：模型通过`.to(device)`放到`device`上，每个batch的图片和标签也在前向传播前用`.to(device)`移到同一设备。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch:0 [0/60000 (0%)]\tLoss:2.368049\n",
      "Train Epoch:0 [12800/60000 (21%)]\tLoss:0.425024\n",
      "Train Epoch:0 [25600/60000 (43%)]\tLoss:0.386299\n",
      "Train Epoch:0 [38400/60000 (64%)]\tLoss:0.320683\n",
      "Train Epoch:0 [51200/60000 (85%)]\tLoss:0.381233\n",
      "Train Epoch:1 [0/60000 (0%)]\tLoss:0.199081\n",
      "Train Epoch:1 [12800/60000 (21%)]\tLoss:0.255350\n",
      "Train Epoch:1 [25600/60000 (43%)]\tLoss:0.302886\n",
      "Train Epoch:1 [38400/60000 (64%)]\tLoss:0.301833\n",
      "Train Epoch:1 [51200/60000 (85%)]\tLoss:0.168701\n",
      "Train Epoch:2 [0/60000 (0%)]\tLoss:0.228473\n",
      "Train Epoch:2 [12800/60000 (21%)]\tLoss:0.233606\n",
      "Train Epoch:2 [25600/60000 (43%)]\tLoss:0.167221\n",
      "Train Epoch:2 [38400/60000 (64%)]\tLoss:0.104043\n",
      "Train Epoch:2 [51200/60000 (85%)]\tLoss:0.191910\n",
      "Train Epoch:3 [0/60000 (0%)]\tLoss:0.210364\n",
      "Train Epoch:3 [12800/60000 (21%)]\tLoss:0.189229\n",
      "Train Epoch:3 [25600/60000 (43%)]\tLoss:0.098104\n",
      "Train Epoch:3 [38400/60000 (64%)]\tLoss:0.074856\n",
      "Train Epoch:3 [51200/60000 (85%)]\tLoss:0.178567\n",
      "Train Epoch:4 [0/60000 (0%)]\tLoss:0.202281\n",
      "Train Epoch:4 [12800/60000 (21%)]\tLoss:0.166148\n",
      "Train Epoch:4 [25600/60000 (43%)]\tLoss:0.139287\n",
      "Train Epoch:4 [38400/60000 (64%)]\tLoss:0.075929\n",
      "Train Epoch:4 [51200/60000 (85%)]\tLoss:0.088468\n",
      "Train Epoch:5 [0/60000 (0%)]\tLoss:0.146214\n",
      "Train Epoch:5 [12800/60000 (21%)]\tLoss:0.130639\n",
      "Train Epoch:5 [25600/60000 (43%)]\tLoss:0.173481\n",
      "Train Epoch:5 [38400/60000 (64%)]\tLoss:0.165874\n",
      "Train Epoch:5 [51200/60000 (85%)]\tLoss:0.100924\n",
      "Train Epoch:6 [0/60000 (0%)]\tLoss:0.101145\n",
      "Train Epoch:6 [12800/60000 (21%)]\tLoss:0.061559\n",
      "Train Epoch:6 [25600/60000 (43%)]\tLoss:0.147568\n",
      "Train Epoch:6 [38400/60000 (64%)]\tLoss:0.054544\n",
      "Train Epoch:6 [51200/60000 (85%)]\tLoss:0.161857\n",
      "Train Epoch:7 [0/60000 (0%)]\tLoss:0.059647\n",
      "Train Epoch:7 [12800/60000 (21%)]\tLoss:0.058544\n",
      "Train Epoch:7 [25600/60000 (43%)]\tLoss:0.059528\n",
      "Train Epoch:7 [38400/60000 (64%)]\tLoss:0.098005\n",
      "Train Epoch:7 [51200/60000 (85%)]\tLoss:0.030139\n",
      "Train Epoch:8 [0/60000 (0%)]\tLoss:0.039722\n",
      "Train Epoch:8 [12800/60000 (21%)]\tLoss:0.068010\n",
      "Train Epoch:8 [25600/60000 (43%)]\tLoss:0.224613\n",
      "Train Epoch:8 [38400/60000 (64%)]\tLoss:0.042400\n",
      "Train Epoch:8 [51200/60000 (85%)]\tLoss:0.147454\n",
      "Train Epoch:9 [0/60000 (0%)]\tLoss:0.077467\n",
      "Train Epoch:9 [12800/60000 (21%)]\tLoss:0.103520\n",
      "Train Epoch:9 [25600/60000 (43%)]\tLoss:0.093393\n",
      "Train Epoch:9 [38400/60000 (64%)]\tLoss:0.099850\n",
      "Train Epoch:9 [51200/60000 (85%)]\tLoss:0.039424\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from torch.optim import Adam\n",
    "\n",
    "# Checkpoint locations. The directory is created before the first save,\n",
    "# because torch.save does not create missing parent directories.\n",
    "MODEL_DIR = \"./06-model\"\n",
    "MODEL_PATH = os.path.join(MODEL_DIR, \"model.pkl\")\n",
    "OPTIMIZER_PATH = os.path.join(MODEL_DIR, \"optimizer.pkl\")\n",
    "\n",
    "# Instantiate the model on the selected device, then its optimizer\n",
    "model = MnistNet().to(device)\n",
    "optimizer = Adam(model.parameters(), lr=0.001)\n",
    "\n",
    "# Resume from earlier checkpoints when present. map_location=device lets a\n",
    "# checkpoint that was written on a GPU be loaded onto the current `device`.\n",
    "if os.path.exists(MODEL_PATH):\n",
    "    model.load_state_dict(torch.load(MODEL_PATH, map_location=device))\n",
    "if os.path.exists(OPTIMIZER_PATH):\n",
    "    optimizer.load_state_dict(torch.load(OPTIMIZER_PATH, map_location=device))\n",
    "\n",
    "\n",
    "def train(epoches, mode=True):\n",
    "    \"\"\"Train the global `model` with the global `optimizer`.\n",
    "\n",
    "    Args:\n",
    "        epoches: number of passes over the data loader.\n",
    "        mode: passed to ``model.train(mode)`` and also used as the ``train``\n",
    "            flag of ``get_dataloader``, so the default trains on the MNIST\n",
    "            training split.\n",
    "\n",
    "    Every 100 batches the model and optimizer state dicts are written to\n",
    "    MODEL_PATH / OPTIMIZER_PATH and the current batch loss is printed.\n",
    "    \"\"\"\n",
    "    model.train(mode=mode)\n",
    "    os.makedirs(MODEL_DIR, exist_ok=True)\n",
    "    # Build the loader once; with shuffle=True every epoch's pass is reshuffled.\n",
    "    data_loader = get_dataloader(train=mode)\n",
    "    for epoch in range(epoches):\n",
    "        for idx, (img, label) in enumerate(data_loader):\n",
    "            output = model(img.to(device))\n",
    "            # F.nll_loss expects log-probabilities; presumably MnistNet ends with\n",
    "            # log_softmax (TODO confirm against the model definition).\n",
    "            loss = F.nll_loss(output, label.to(device))\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if idx % 100 == 0:\n",
    "                # TODO: consider recording the epoch/step in the checkpoint file name.\n",
    "                torch.save(model.state_dict(), MODEL_PATH)\n",
    "                torch.save(optimizer.state_dict(), OPTIMIZER_PATH)\n",
    "                print('Train Epoch:{} [{}/{} ({:.0f}%)]\\tLoss:{:.6f}'.format(\n",
    "                    epoch, idx * len(img), len(data_loader.dataset),\n",
    "                    100. * idx / len(data_loader), loss.item()))\n",
    "\n",
    "\n",
    "if __name__ == '__main__':\n",
    "    train(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型的保存和加载\n",
    "### 1、模型的保存\n",
    "`torch.save(要保存的对象, 保存路径)`，通常保存的是模型/优化器的`state_dict()`：\n",
    "```python\n",
    "torch.save(mnist_net.state_dict(), 'model/mnist_net.pt')  # 保存模型的参数\n",
    "torch.save(optimizer.state_dict(), 'results/mnist_optimizer.pt')  # 保存优化器的参数\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2、模型的加载\n",
    "`模型/优化器.load_state_dict(torch.load(路径))`\n",
    "\n",
    "```python\n",
    "mnist_net.load_state_dict(torch.load(\"model/mnist_net.pt\"))\n",
    "optimizer.load_state_dict(torch.load(\"results/mnist_optimizer.pt\"))\n",
    "```\n",
    "\n",
    "如果保存时参数在GPU上、加载时只有CPU，可以传入`map_location`，例如`torch.load(路径, map_location='cpu')`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 模型的评估"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在GPU上评估时需要注意：输入图片和标签都要先用`.to(device)`移到模型所在的设备上；把结果取出为Python数值时要用`.item()`（或先`.cpu()`再比较/计算）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "平均准确率， 平均损失 0.9585 0.14125641956925392\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def test():\n",
    "    \"\"\"Evaluate the global `model` on the MNIST test split.\n",
    "\n",
    "    Prints the accuracy and mean NLL loss over all test samples and returns\n",
    "    them as an ``(accuracy, loss)`` tuple of Python floats. Both are weighted\n",
    "    per sample, so a smaller final batch does not skew the averages.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the test data loader yields no samples.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    test_dataloader = get_dataloader(train=False, batch_size=TEST_BATCH_SIZE)\n",
    "    total_loss = 0.0\n",
    "    total_correct = 0\n",
    "    total_samples = 0\n",
    "    # Evaluation needs no autograd graph.\n",
    "    with torch.no_grad():\n",
    "        for img, target in test_dataloader:\n",
    "            # Inputs and labels must live on the same device as the model.\n",
    "            img, target = img.to(device), target.to(device)\n",
    "            output = model(img)  # [batch_size, 10]\n",
    "            # reduction='sum' so dividing by the sample count gives the exact mean.\n",
    "            total_loss += F.nll_loss(output, target, reduction='sum').item()\n",
    "            # The index of the largest score in each row is the predicted class.\n",
    "            pred = output.argmax(dim=-1)\n",
    "            total_correct += pred.eq(target).sum().item()\n",
    "            total_samples += target.size(0)\n",
    "    if total_samples == 0:\n",
    "        raise ValueError(\"test dataloader yielded no samples\")\n",
    "    acc = total_correct / total_samples\n",
    "    loss = total_loss / total_samples\n",
    "    print(\"平均准确率， 平均损失\", acc, loss)\n",
    "    return acc, loss\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
